{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "beObUOFyuRjT"
      },
      "source": [
        "##### Copyright 2021 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nQnmcm0oI1Q-"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eutDVTs9aJEL"
      },
      "source": [
        "# Replay Buffers\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/5_replay_buffers_tutorial\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003e\n",
        "    View on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/agents/blob/master/docs/tutorials/5_replay_buffers_tutorial.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/agents/blob/master/docs/tutorials/5_replay_buffers_tutorial.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/agents/docs/tutorials/5_replay_buffers_tutorial.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8aPHF9kXFggA"
      },
      "source": [
        "## Introduction\n",
        "\n",
        "Reinforcement learning algorithms use replay buffers to store trajectories of experience when executing a policy in an environment. During training, replay buffers are queried for a subset of the trajectories (either a sequential subset or a sample) to \"replay\" the agent's experience.\n",
        "\n",
        "In this colab, we explore two types of replay buffers: python-backed and tensorflow-backed, sharing a common API. In the following sections, we describe the API, each of the buffer implementations and how to use them during data collection training.\n"
      ]
    },
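    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea concrete, here is a minimal pure-Python sketch of a uniform replay buffer. It is an illustration only and not how TF-Agents implements its buffers; the class name `TinyReplayBuffer` and its methods are hypothetical. The sketch keeps at most `capacity` items, overwrites the oldest item once it is full, and samples uniformly at random.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "\n",
        "class TinyReplayBuffer:\n",
        "  # Hypothetical fixed-capacity buffer with uniform sampling (illustration only).\n",
        "\n",
        "  def __init__(self, capacity):\n",
        "    self.capacity = capacity\n",
        "    self._items = []\n",
        "    self._next_index = 0  # Position that the next write goes to.\n",
        "\n",
        "  def add(self, item):\n",
        "    if len(self._items) < self.capacity:\n",
        "      self._items.append(item)\n",
        "    else:\n",
        "      # Buffer is full: overwrite the oldest entry.\n",
        "      self._items[self._next_index] = item\n",
        "    self._next_index = (self._next_index + 1) % self.capacity\n",
        "\n",
        "  def sample(self, batch_size):\n",
        "    # Uniform sampling with replacement over the stored items.\n",
        "    return [random.choice(self._items) for _ in range(batch_size)]\n",
        "\n",
        "\n",
        "buffer = TinyReplayBuffer(capacity=3)\n",
        "for step in range(5):\n",
        "  buffer.add(step)\n",
        "\n",
        "print(sorted(buffer._items))  # [2, 3, 4]: steps 0 and 1 were overwritten.\n",
        "print(buffer.sample(batch_size=2))\n",
        "```\n",
        "\n",
        "The buffers described in the rest of this colab follow the same basic contract (add items, then sample them), but operate on batched, structured data described by a spec."
      ]
    },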
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1uSlqYgvaG9b"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GztmUpWKZ7kq"
      },
      "source": [
        "Install tf-agents if you haven't already."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TnE2CgilrngG"
      },
      "outputs": [],
      "source": [
        "!pip install tf-agents\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "whYNP894FSkA"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "from tf_agents import specs\n",
        "from tf_agents.agents.dqn import dqn_agent\n",
        "from tf_agents.drivers import dynamic_step_driver\n",
        "from tf_agents.environments import suite_gym\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.networks import q_network\n",
        "from tf_agents.replay_buffers import py_uniform_replay_buffer\n",
        "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n",
        "from tf_agents.specs import tensor_spec\n",
        "from tf_agents.trajectories import time_step"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xcQWclL9FpZl"
      },
      "source": [
        "## Replay Buffer API\n",
        "\n",
        "The Replay Buffer class has the following definition and methods:\n",
        "\n",
        "```python\n",
        "class ReplayBuffer(tf.Module):\n",
        "  \"\"\"Abstract base class for TF-Agents replay buffer.\"\"\"\n",
        "\n",
        "  def __init__(self, data_spec, capacity):\n",
        "    \"\"\"Initializes the replay buffer.\n",
        "\n",
        "    Args:\n",
        "      data_spec: A spec or a list/tuple/nest of specs describing\n",
        "        a single item that can be stored in this buffer\n",
        "      capacity: number of elements that the replay buffer can hold.\n",
        "    \"\"\"\n",
        "\n",
        "  @property\n",
        "  def data_spec(self):\n",
        "    \"\"\"Returns the spec for items in the replay buffer.\"\"\"\n",
        "\n",
        "  @property\n",
        "  def capacity(self):\n",
        "    \"\"\"Returns the capacity of the replay buffer.\"\"\"\n",
        "\n",
        "  def add_batch(self, items):\n",
        "    \"\"\"Adds a batch of items to the replay buffer.\"\"\"\n",
        "\n",
        "  def get_next(self,\n",
        "               sample_batch_size=None,\n",
        "               num_steps=None,\n",
        "               time_stacked=True):\n",
        "    \"\"\"Returns an item or batch of items from the buffer.\"\"\"\n",
        "\n",
        "  def as_dataset(self,\n",
        "                 sample_batch_size=None,\n",
        "                 num_steps=None,\n",
        "                 num_parallel_calls=None):\n",
        "    \"\"\"Creates and returns a dataset that returns entries from the buffer.\"\"\"\n",
        "\n",
        "\n",
        "  def gather_all(self):\n",
        "    \"\"\"Returns all the items in buffer.\"\"\"\n",
        "    return self._gather_all()\n",
        "\n",
        "  def clear(self):\n",
        "    \"\"\"Resets the contents of replay buffer\"\"\"\n",
        "\n",
        "```\n",
        "\n",
        "Note that when the replay buffer object is initialized, it requires the `data_spec` of the elements that it will store. This spec corresponds to the `TensorSpec` of trajectory elements that will be added to the buffer. This spec is usually acquired by looking at an agent's `agent.collect_data_spec` which defines the shapes, types, and structures expected by the agent when training (more on that later)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X3Yrxg36Ik1x"
      },
      "source": [
        "## TFUniformReplayBuffer\n",
        "\n",
        "`TFUniformReplayBuffer` is the most commonly used replay buffer in TF-Agents, thus we will use it in our tutorial here. In `TFUniformReplayBuffer` the backing buffer storage is done by tensorflow variables and thus is part of the compute graph. \n",
        "\n",
        "The buffer stores batches of elements and has a maximum capacity `max_length` elements per batch segment. Thus, the total buffer capacity is `batch_size` x `max_length` elements. The elements stored in the buffer must all have a matching data spec. When the replay buffer is used for data collection, the spec is the agent's collect data spec.\n"
      ]
    },
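    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check on the capacity arithmetic, the sketch below models the storage for a single spec entry as a NumPy array of shape `[batch_size, max_length] + item_shape`. This layout is an assumption made for illustration (the actual buffer keeps its data in TensorFlow variables), and the variable names are hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "batch_size = 32     # Number of batch segments.\n",
        "max_length = 1000   # Elements per batch segment.\n",
        "item_shape = (5,)   # Shape of one stored element, e.g. a lidar reading.\n",
        "\n",
        "# Hypothetical storage layout: max_length slots for each batch segment.\n",
        "storage = np.zeros((batch_size, max_length) + item_shape, dtype=np.float32)\n",
        "\n",
        "total_capacity = batch_size * max_length\n",
        "print(total_capacity)  # 32000\n",
        "print(storage.shape)   # (32, 1000, 5)\n",
        "```"
      ]
    },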
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lYk-bn2taXlw"
      },
      "source": [
        "### Creating the buffer:\n",
        "To create a `TFUniformReplayBuffer` we pass in:\n",
        "1. the spec of the data elements that the buffer will store\n",
        "2. the `batch size` corresponding to the batch size of the buffer \n",
        "3. the `max_length` number of elements per batch segment\n",
        "\n",
        "Here is an example of creating a `TFUniformReplayBuffer` with sample data specs, `batch_size` 32 and `max_length` 1000."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dj4_-77_5ExP"
      },
      "outputs": [],
      "source": [
        "data_spec =  (\n",
        "        tf.TensorSpec([3], tf.float32, 'action'),\n",
        "        (\n",
        "            tf.TensorSpec([5], tf.float32, 'lidar'),\n",
        "            tf.TensorSpec([3, 2], tf.float32, 'camera')\n",
        "        )\n",
        ")\n",
        "\n",
        "batch_size = 32\n",
        "max_length = 1000\n",
        "\n",
        "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
        "    data_spec,\n",
        "    batch_size=batch_size,\n",
        "    max_length=max_length)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XB8rOw5ATDD2"
      },
      "source": [
        "### Writing to the buffer:\n",
        "To add elements to the replay buffer, we use the `add_batch(items)` method where `items` is a list/tuple/nest of tensors representing the batch of items to be added to the buffer. Each element of `items` must have an outer dimension equal `batch_size` and the remaining dimensions must adhere to the data spec of the item (same as the data specs passed to the replay buffer constructor). \n",
        "\n",
        "Here's an example of adding a batch of items \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nOvkp4vJhBOT"
      },
      "outputs": [],
      "source": [
        "action = tf.constant(1 * np.ones(\n",
        "    data_spec[0].shape.as_list(), dtype=np.float32))\n",
        "lidar = tf.constant(\n",
        "    2 * np.ones(data_spec[1][0].shape.as_list(), dtype=np.float32))\n",
        "camera = tf.constant(\n",
        "    3 * np.ones(data_spec[1][1].shape.as_list(), dtype=np.float32))\n",
        "  \n",
        "values = (action, (lidar, camera))\n",
        "values_batched = tf.nest.map_structure(lambda t: tf.stack([t] * batch_size),\n",
        "                                       values)\n",
        "  \n",
        "replay_buffer.add_batch(values_batched)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "smnVAxHghKly"
      },
      "source": [
        "### Reading from the buffer\n",
        "\n",
        "There are three ways to read data from the `TFUniformReplayBuffer`:\n",
        "\n",
        "1. `get_next()` - returns one sample from the buffer. The sample batch size and number of timesteps returned can be specified via arguments to this method.\n",
        "2. `as_dataset()` - returns the replay buffer as a `tf.data.Dataset`. One can then create a dataset iterator and iterate through the samples of the items in the buffer.\n",
        "3. `gather_all()` - returns all the items in the buffer as a Tensor with shape `[batch, time, data_spec]`\n",
        "\n",
        "Below are examples of how to read from the replay buffer using each of these methods:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IlQ1eGhohM3M"
      },
      "outputs": [],
      "source": [
        "# add more items to the buffer before reading\n",
        "for _ in range(5):\n",
        "  replay_buffer.add_batch(values_batched)\n",
        "\n",
        "# Get one sample from the replay buffer with batch size 10 and 1 timestep:\n",
        "\n",
        "sample = replay_buffer.get_next(sample_batch_size=10, num_steps=1)\n",
        "\n",
        "# Convert the replay buffer to a tf.data.Dataset and iterate through it\n",
        "dataset = replay_buffer.as_dataset(\n",
        "    sample_batch_size=4,\n",
        "    num_steps=2)\n",
        "\n",
        "iterator = iter(dataset)\n",
        "print(\"Iterator trajectories:\")\n",
        "trajectories = []\n",
        "for _ in range(3):\n",
        "  t, _ = next(iterator)\n",
        "  trajectories.append(t)\n",
        "  \n",
        "print(tf.nest.map_structure(lambda t: t.shape, trajectories))\n",
        "\n",
        "# Read all elements in the replay buffer:\n",
        "trajectories = replay_buffer.gather_all()\n",
        "\n",
        "print(\"Trajectories from gather all:\")\n",
        "print(tf.nest.map_structure(lambda t: t.shape, trajectories))\n"
      ]
    },
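    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shapes printed above follow from how samples are drawn: each sampled item is a window of `num_steps` consecutive timesteps, and `sample_batch_size` such windows are stacked together. The NumPy sketch below mimics that behaviour under simplifying assumptions (a single spec entry, windows taken within one batch segment, and a hypothetical `sample_windows` helper); it is not the TF-Agents implementation.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(seed=0)\n",
        "\n",
        "batch_size, num_stored_steps = 32, 6\n",
        "# Pretend storage for one spec entry: [batch segment, time, 5].\n",
        "storage = rng.normal(size=(batch_size, num_stored_steps, 5)).astype(np.float32)\n",
        "\n",
        "\n",
        "def sample_windows(storage, sample_batch_size, num_steps):\n",
        "  # Hypothetical helper: draw windows of consecutive timesteps uniformly.\n",
        "  segments = rng.integers(0, storage.shape[0], size=sample_batch_size)\n",
        "  # Latest start index that still leaves room for num_steps timesteps.\n",
        "  max_start = storage.shape[1] - num_steps\n",
        "  starts = rng.integers(0, max_start + 1, size=sample_batch_size)\n",
        "  return np.stack(\n",
        "      [storage[b, s:s + num_steps] for b, s in zip(segments, starts)])\n",
        "\n",
        "\n",
        "sample = sample_windows(storage, sample_batch_size=4, num_steps=2)\n",
        "print(sample.shape)  # (4, 2, 5)\n",
        "```\n",
        "\n",
        "The leading dimensions `(4, 2)` correspond to `sample_batch_size` and `num_steps`; the trailing dimensions come from the item's spec."
      ]
    },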
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BcS49HrNF34W"
      },
      "source": [
        "## PyUniformReplayBuffer\n",
        "`PyUniformReplayBuffer`  has the same functionaly as the `TFUniformReplayBuffer` but instead of tf variables, its data is stored in numpy arrays. This buffer can be used for out-of-graph data collection. Having the backing storage in numpy may make it easier for some applications to do data manipulation (such as indexing for updating priorities) without using Tensorflow variables. However, this implementation won't have the benefit of graph optimizations with Tensorflow. \n",
        "\n",
        "Below is an example of instantiating a `PyUniformReplayBuffer` from the agent's policy trajectory specs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F4neLPpL25wI"
      },
      "outputs": [],
      "source": [
        "replay_buffer_capacity = 1000*32 # same capacity as the TFUniformReplayBuffer\n",
        "\n",
        "py_replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(\n",
        "    capacity=replay_buffer_capacity,\n",
        "    data_spec=tensor_spec.to_nest_array_spec(data_spec))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9V7DEcB8IeiQ"
      },
      "source": [
        "## Using replay buffers during training\n",
        "Now that we know how to create a replay buffer, write items to it and read from it, we can use it to store trajectories during training of our agents. \n",
        "\n",
        "### Data collection\n",
        "First, let's look at how to use the replay buffer during data collection.\n",
        "\n",
        "In TF-Agents we use a `Driver` (see the Driver tutorial for more details) to collect experience in an environment. To use a `Driver`, we specify an `Observer` that is a function for the `Driver` to execute when it receives a trajectory. \n",
        "\n",
        "Thus, to add trajectory elements to the replay buffer, we add an observer that calls `add_batch(items)` to add a batch of items on the replay buffer. \n",
        "\n",
        "Below is an example of this with `TFUniformReplayBuffer`. We first create an environment, a network and an agent. Then we create a `TFUniformReplayBuffer`. Note that the specs of the trajectory elements in the replay buffer are equal to the agent's collect data spec. We then set its `add_batch` method as the observer for the driver that will do the data collect during our training:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pCbTDO3Z5UCS"
      },
      "outputs": [],
      "source": [
        "env = suite_gym.load('CartPole-v0')\n",
        "tf_env = tf_py_environment.TFPyEnvironment(env)\n",
        "\n",
        "q_net = q_network.QNetwork(\n",
        "    tf_env.time_step_spec().observation,\n",
        "    tf_env.action_spec(),\n",
        "    fc_layer_params=(100,))\n",
        "\n",
        "agent = dqn_agent.DqnAgent(\n",
        "    tf_env.time_step_spec(),\n",
        "    tf_env.action_spec(),\n",
        "    q_network=q_net,\n",
        "    optimizer=tf.compat.v1.train.AdamOptimizer(0.001))\n",
        "\n",
        "replay_buffer_capacity = 1000\n",
        "\n",
        "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
        "    agent.collect_data_spec,\n",
        "    batch_size=tf_env.batch_size,\n",
        "    max_length=replay_buffer_capacity)\n",
        "\n",
        "# Add an observer that adds to the replay buffer:\n",
        "replay_observer = [replay_buffer.add_batch]\n",
        "\n",
        "collect_steps_per_iteration = 10\n",
        "collect_op = dynamic_step_driver.DynamicStepDriver(\n",
        "  tf_env,\n",
        "  agent.collect_policy,\n",
        "  observers=replay_observer,\n",
        "  num_steps=collect_steps_per_iteration).run()"
      ]
    },
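    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, an observer is just a callable that the driver invokes with every trajectory it produces. The pure-Python sketch below illustrates that contract with a hypothetical `run_driver` function and fake trajectories; it does not use the TF-Agents `Driver` classes.\n",
        "\n",
        "```python\n",
        "def run_driver(num_steps, observers):\n",
        "  # Hypothetical stand-in for a driver: produce fake trajectories and\n",
        "  # pass each one to every observer.\n",
        "  for step in range(num_steps):\n",
        "    trajectory = {'step': step, 'reward': 1.0}\n",
        "    for observer in observers:\n",
        "      observer(trajectory)\n",
        "\n",
        "\n",
        "collected = []  # Plays the role of the replay buffer.\n",
        "rewards = []    # A second observer can track metrics at the same time.\n",
        "\n",
        "run_driver(\n",
        "    num_steps=3,\n",
        "    observers=[collected.append, lambda t: rewards.append(t['reward'])])\n",
        "\n",
        "print(len(collected))  # 3\n",
        "print(sum(rewards))    # 3.0\n",
        "```\n",
        "\n",
        "Because observers are plain callables, passing `replay_buffer.add_batch` in the observer list is all that is needed to route collected trajectories into the buffer."
      ]
    },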
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "huGCDbO4GAF1"
      },
      "source": [
        "### Reading data for a train step\n",
        "\n",
        "After adding trajectory elements to the replay buffer, we can read batches of trajectories from the replay buffer to use as input data for a train step.\n",
        "\n",
        "Here is an example of how to train on trajectories from the replay buffer in a training loop: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gg8SUyXXnSMr"
      },
      "outputs": [],
      "source": [
        "# Read the replay buffer as a Dataset,\n",
        "# read batches of 4 elements, each with 2 timesteps:\n",
        "dataset = replay_buffer.as_dataset(\n",
        "    sample_batch_size=4,\n",
        "    num_steps=2)\n",
        "\n",
        "iterator = iter(dataset)\n",
        "\n",
        "num_train_steps = 10\n",
        "\n",
        "for _ in range(num_train_steps):\n",
        "  trajectories, _ = next(iterator)\n",
        "  loss = agent.train(experience=trajectories)\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "TF-Agents Replay Buffers Tutorial.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
